import time
import os
import math
from dolfin import *
import matplotlib.pyplot as plt
import matplotlib as mpl

# Rectangular computational domain [0, lx] x [0, ly], meshed with 100 x 30
# subdivisions in the 'crossed' diagonal pattern.
lx, ly = 1.0, 0.3
mesh = RectangleMesh(Point(0, 0), Point(lx, ly), 100, 30, 'crossed')

# Uncomment to inspect the mesh:
# plot(mesh)
# plt.show()


# Tag boundary facets by side: 1 = left (x = 0), 2 = right (x = lx),
# 3 = bottom (y = 0), 4 = top (y = ly). All other facets keep marker 0.
# The sides are tested in this fixed order, so a corner facet takes the
# first matching side.
bndry = MeshFunction("size_t", mesh, mesh.topology().dim()-1, 0)
for facet in facets(mesh):
    centre = facet.midpoint()
    x, y = centre[0], centre[1]
    if near(x, 0.0):
        marker = 1
    elif near(x, lx):
        marker = 2
    elif near(y, 0.0):
        marker = 3
    elif near(y, ly):
        marker = 4
    else:
        continue
    bndry[facet] = marker






# Mixed function space W = P1 (scalar density n) x P2 (vector velocity v).
N = FiniteElement('P', mesh.ufl_cell(), 1)
V = VectorElement('P', mesh.ufl_cell(), 2)
element = MixedElement([N, V])
W = FunctionSpace(mesh, element)


# x-velocity prescribed on the left boundary; also used for the initial field.
vx_left = 1.0e1

# Standalone (non-mixed) spaces matching the components of W, used to
# interpolate the initial data before it is copied into the mixed function.
N_fs = FunctionSpace(mesh, 'P', 1)
V_fs = VectorFunctionSpace(mesh, 'P', 2)
n00 = Expression(("1.0e14"), domain=mesh, degree=1)
# Expression components must be C++ code strings. The numeric vx_left is
# therefore passed as the named parameter `vx` instead of being put in the
# tuple directly.
v00 = Expression(("vx", "0.0"), vx=vx_left, domain=mesh, degree=2)




# Dirichlet boundary conditions. The integer after `bndry` selects the facet
# marker set by the facet-marking loop: 1 = left (x = 0), 2 = right (x = lx),
# 3 = bottom (y = 0), 4 = top (y = ly).

# Density n (component 0 of W).
n_left_boundary = DirichletBC(W.sub(0), Constant(1.0e18), bndry, 1)
n_right_boundary = DirichletBC(W.sub(0), Constant(1.0e14), bndry, 2)
#n_down_boundary = DirichletBC(W.sub(0), Constant(1.0e14), bndry,3)
#n_up_boundary = DirichletBC(W.sub(0), Constant(1.0e14), bndry,4)

# Velocity v (component 1 of W).
v_left_boundary = DirichletBC(W.sub(1), Constant((vx_left, 0)), bndry, 1)
v_right_boundary = DirichletBC(W.sub(1), Constant((1.0e2, 0)), bndry, 2)
#v_down_boundary = DirichletBC(W.sub(1), Constant((0, 0)), bndry,3)
#v_up_boundary = DirichletBC(W.sub(1), Constant((0, 0)), bndry,4)

# Only the left and right sides are constrained; top and bottom are left free.
bcs =[n_left_boundary, n_right_boundary, v_left_boundary, v_right_boundary]




# Time stepping
dt = 1.0e-3
t_end = 10.0
# theta-scheme weight: 1.0 = backward (implicit) Euler, 0.5 = Crank-Nicolson.
theta=1.0
k=0.01

# Diffusion coefficients and other physical parameters.
# NOTE(review): only `viscosity` is used by the variational forms in this
# script; D, Dp, Dr, mi, niup, niur, k and T are currently unused.
D = 10.0
Dp = 1000.0
Dr = 1.0
mi = 1.0
niup = 1000.0
niur = 1.0
viscosity = 1.0
# Presumably a temperature of 100 eV in joules (1 eV = 1.602e-19 J); the
# previous factor 1.62e-19 looked like a typo. TODO confirm intended units.
T = 100.0 * 1.602e-19

# Test and trial functions on the mixed space (n, v).
n_, v_ = TestFunctions(W)
nt, vt = TrialFunctions(W)

# w holds the unknown at the new time level, w0 the state from the previous
# step. split() returns UFL components that follow the values stored in them.
w = Function(W)
n, v = split(w)
w0 = Function(W)
n0, v0 = split(w0)

print(type(n_), type(v_))

# SUPG-style stabilisation parameter, following the FEniCS SUPG example.
# NOTE(review): tau is not referenced by any of the variational forms below.
h = 2.0 * Circumradius(mesh)   # cell circumdiameter
c = 1.0e-15                    # keeps the denominator non-zero when |v| = 0
vnorm = sqrt(inner(v, v))
tau = h / (2.0 * vnorm + c)
print(type(h), type(vnorm), type(tau))

# Galerkin/least-squares (GLS) stabilisation constant and h^2.
C = 2.0
hh = inner(h, h)

# Continuity residual div(n v), tested with n_, plus a GLS term, evaluated at
# the previous time level (n0, v0).
F0_eq1 = div(n0 * v0) * n_ * dx
F0_eq1 = F0_eq1 + C * hh * div(n0 * v0) * div(n_ * v0) * dx

# The same residual at the new time level (n, v). Its GLS term must go into
# F1_eq1. Adding it to F0_eq1 would scale it by (1 - theta), which is 0 for
# backward Euler, and drop the stabilisation of the implicit step.
F1_eq1 = div(n * v) * n_ * dx
F1_eq1 = F1_eq1 + C * hh * div(n * v) * div(n_ * v) * dx


# Velocity equation: convection (v . grad) v plus viscous diffusion.
# In UFL, grad(v)[i, j] = dv_i/dx_j, so grad(v)*v is the convective term
# (v . grad) v. nabla_grad(v)*v is its transpose, grad(|v|^2 / 2), which
# differs by the vorticity (Lamb) term.
F0_eq2 = inner(grad(v0)*v0, v_) * dx
F0_eq2 = F0_eq2 + viscosity * inner(nabla_grad(v0), nabla_grad(v_))*dx

F1_eq2 = inner(grad(v)*v, v_) * dx
F1_eq2 = F1_eq2 + viscosity * inner(nabla_grad(v), nabla_grad(v_))*dx

# Full spatial residuals at the previous (F0) and new (F1) time levels.
F0 = F0_eq1 + F0_eq2
F1 = F1_eq1 + F1_eq2

# theta-scheme: time-difference terms / dt + (1 - theta)*F0 + theta*F1.
#F = (inner((n-n0),n_)/dt + inner((n * v - n0 * v0),v_)/dt)*dx + (1.0-theta)*F0  + theta*F1
F = (inner((n-n0),n_)/dt + inner((v - v0),v_)/dt)*dx + (1.0-theta)*F0  + theta*F1

# Newton solver for the nonlinear residual F(w) = 0 with the BCs above.
J = derivative(F, w)  # Jacobian obtained by automatic differentiation of F
problem = NonlinearVariationalProblem(F, w, bcs, J)
solver = NonlinearVariationalSolver(problem)

prm = solver.parameters
info(prm, True)  # print the full parameter tree (before it is modified)
prm['nonlinear_solver'] = 'newton'
newton_prm = prm['newton_solver']
newton_prm['absolute_tolerance'] = 1E-12
newton_prm['relative_tolerance'] = 1e-12
newton_prm['maximum_iterations'] = 200
# NOTE(review): n is of order 1e14-1e18. Confirm an absolute residual
# tolerance of 1e-12 is attainable; the relative one may do the real work.
# Alternative linear solver settings:
# newton_prm['linear_solver'] = 'petsc'
# list_linear_solver_methods()

# Time-stepping
t = dt

# Initial condition. The forms reference n0, v0 = split(w0), so the initial
# data must be written into w0 itself. Rebinding the Python names n0/v0 to new
# Functions would leave w0 (and therefore the forms) at zero.
n_init = interpolate(n00, N_fs)
v_init = interpolate(v00, V_fs)
assign(w0, [n_init, v_init])
# Start the first Newton solve from the initial state rather than from w = 0.
w.assign(w0)

# Disabled plotting of the initial state. To re-enable, plot n_init/v_init
# instead of n0/v0, and define `b` first (it is not defined in this script).
'''
fig = plt.figure()

plt.subplot(1, 3, 1)
c = plot(n0, mode='color', cmap = mpl.cm.Spectral_r)
plt.colorbar(c)

plt.subplot(1, 3, 2)
c = plot(v0, mode='color', cmap = mpl.cm.Spectral_r)
plt.colorbar(c)

plt.subplot(1, 3, 3)
c = plot(b)

fig.savefig("nv.pdf")
plt.show()
'''

# Step counter and output cadence (plot every i_out steps).
i = 0
i_out = 10

while t < t_end:

    print("t =", t, "end t=", t_end)

    # Solve for the new time level. begin() raises the dolfin log indentation
    # level and must be balanced by end(), otherwise it grows every step.
    begin("Solving ....")
    solver.solve()
    end()

    # Accept the step: the new solution becomes the previous state.
    w0.assign(w)
    t += dt

    i = i + 1
    if i % i_out == 0:
        # Extract the components of the mixed solution for plotting. Use new
        # names so the module-level UFL components n, v (and the constant c
        # used in tau) are not clobbered.
        n_out, v_out = w.split()

        fig = plt.figure()

        plt.subplot(1, 2, 1)
        im = plot(n_out, mode='color', cmap = mpl.cm.Spectral_r)
        plt.colorbar(im)

        plt.subplot(1, 2, 2)
        im = plot(v_out.sub(0))
        plt.colorbar(im)

        fig.savefig("nv.pdf")
        plt.show()
        # Release the figure; otherwise one figure accumulates per output step.
        plt.close(fig)